import torch
import pandas as pd
from torch.utils.data import DataLoader,Dataset
import torch.nn.functional as F
import numpy as np

# Use the GPU when one is available; previously this unconditionally
# requested CUDA and crashed on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fixed length every name is padded/encoded to (see nametovec).
max_name = 20

def nametovec(name, length=20):
    """Encode a name as a fixed-length list of character ordinals.

    Args:
        name: name string to encode.
        length: output vector length (default 20, matching ``max_name``).

    Returns:
        A list of exactly ``length`` ints: ``ord()`` of each character,
        truncated to ``length`` characters and right-padded with 50 for
        shorter names (50 is the original padding code, kept unchanged so
        existing checkpoints remain compatible).
    """
    # Truncate over-long names: the original returned a longer vector for
    # names with > 20 characters, which breaks DataLoader batch stacking.
    vec = [ord(ch) for ch in name[:length]]
    if len(vec) < length:
        vec = vec + [50] * (length - len(vec))
    return vec

class NameData(Dataset):
    """Dataset of (encoded-name tensor, country-index) pairs.

    Reads a gzipped, headerless CSV where column 0 is the name and
    column 1 is the country label.
    """

    def __init__(self, train=True):
        path = "data/names_train.csv.gz" if train else "data/names_test.csv.gz"
        self.data = pd.read_csv(path, compression="gzip", header=None)
        self.name_list = self.data[0].tolist()
        # Unique countries in first-seen order.
        self.country_list = pd.unique(self.data[1])
        self.len = len(self.data)
        # Dense integer id per country, assigned in first-seen order.
        self.country_dict = {country: idx
                             for idx, country in enumerate(self.country_list)}

    def __getitem__(self, item):
        name_vec = torch.tensor(nametovec(self.name_list[item]))
        label = self.country_dict[self.data[1][item]]
        return name_vec, label

    def __len__(self):
        return self.len



class Model(torch.nn.Module):
    """Character embedding -> 3-layer BiLSTM -> linear -> softmax.

    Input: LongTensor of character codes, shape (batch, seq_len).
    Output: per-class probabilities, shape (batch, 18).
    """

    def __init__(self):
        super().__init__()
        # NOTE: attribute names (including the "embdding" typo) and the
        # construction order are kept exactly as before so that saved
        # state_dict checkpoints keep loading.
        self.lstm = torch.nn.LSTM(input_size=300,
                                  num_layers=3,
                                  bidirectional=True,
                                  hidden_size=10,
                                  batch_first=True)
        self.learn = torch.nn.Linear(20, 18)
        self.embdding = torch.nn.Embedding(200, 300)

    def forward(self, input):
        embedded = self.embdding(input)
        _, (h_n, _) = self.lstm(embedded)
        # For a bidirectional LSTM, h_n[-2] is the last layer's final
        # forward state and h_n[-1] its final backward state.
        final = torch.cat([h_n[-2, :, :], h_n[-1, :, :]], dim=-1)
        return F.softmax(self.learn(final), dim=-1)

# Restore model and optimizer state from the last checkpoint.
# NOTE(review): this assumes model/model.pkl and model/optimizer.pkl already
# exist — a first run with no checkpoints will fail here; confirm intended.
model = Model().to(device)
model.load_state_dict(torch.load("model/model.pkl"))
opt = torch.optim.Adam(model.parameters(),lr = 0.01)
opt.load_state_dict(torch.load("model/optimizer.pkl"))
# loss_fn = torch.nn.CrossEntropyLoss()
# NLLLoss expects log-probabilities; train() applies torch.log to the
# model's softmax output before calling it.
loss_fn = torch.nn.NLLLoss()

def train():
    """Run one pass over the training split, checkpointing every 10 batches."""
    dataset = NameData(train=True)
    loader = DataLoader(dataset, batch_size=100, shuffle=True)

    for idx, (names, labels) in enumerate(loader):
        names = names.to(device)
        labels = labels.to(device)

        # NLLLoss wants log-probabilities; the model emits softmax, so take
        # the log here. (F.log_softmax inside the model would be more
        # numerically stable — left unchanged to preserve behavior.)
        log_probs = torch.log(model(names))
        loss = loss_fn(log_probs, labels)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Periodically persist model + optimizer state and report the loss.
        if idx % 10 == 0:
            torch.save(model.state_dict(), "./model/model.pkl")
            torch.save(opt.state_dict(), "./model/optimizer.pkl")
            print(loss.item(), "===", idx)

# train()
# name_data = NameData()
# load = DataLoader(name_data,batch_size=2)
# for idx ,(i,j) in enumerate(load):
#     print(i)
#     print(j)


# for i in range(100):
#     train()

def test1():
    """Evaluate the model on the test split and report mean batch accuracy.

    Fixes two defects in the original: evaluation ran with gradient
    tracking enabled (wasting memory/time), and the per-batch accuracies
    were accumulated but never reported.
    """
    name_data = NameData(train=False)
    load = DataLoader(name_data, batch_size=10, shuffle=False)
    accs = []  # per-batch accuracy values

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for idx, (input, label) in enumerate(load):
            input = input.to(device)
            label = label.to(device)

            out = model(input)
            # argmax over class probabilities -> predicted class ids.
            pred = out.max(dim=-1)[-1]

            accs.append(pred.eq(label).float().mean().item())

            # Debug output: predictions vs. ground-truth for each batch.
            print(pred)
            print(label)
            print()
            print()

    # Previously computed but silently discarded.
    print("mean accuracy:", np.mean(accs))

# Continue training for 100 epochs; train() saves checkpoints as it goes.
for i in range(100):
    train()

# test1()
